import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


# 1D maze environment: the agent starts at position 0 and must reach position 10
class SimpleMazeEnv:
    def __init__(self):
        self.state = 0  # current position
        self.target = 10  # goal position
        self.max_steps = 20  # episode truncation limit
        self.steps = 0  # steps taken in the current episode

    def reset(self):
        self.state = 0
        self.steps = 0
        return self.state

    def step(self, action):
        if action == 0:  # move left
            self.state = max(0, self.state - 1)
        elif action == 1:  # move right
            self.state = min(self.target, self.state + 1)
        self.steps += 1

        # Reward is the negative distance to the goal, so reward increases
        # (toward 0) as the agent approaches the target.
        reward = -abs(self.state - self.target)
        # The episode ends on reaching the goal or hitting the step limit.
        done = (self.state == self.target) or (self.steps >= self.max_steps)
        return self.state, reward, done
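

# A minimal sanity-check sketch, assuming the SimpleMazeEnv above: roll out one
# episode with uniformly random actions. random_rollout is a helper added for
# illustration and is not called during training; it is only useful for
# verifying that reset/step/done behave as expected.
def random_rollout(env):
    state = env.reset()
    trajectory = [state]
    total_reward = 0.0
    done = False
    while not done:
        action = np.random.randint(2)  # pick left (0) or right (1) uniformly
        state, reward, done = env.step(action)
        trajectory.append(state)
        total_reward += reward
    return trajectory, total_reward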


# Policy network: maps a scalar state to a probability distribution over actions
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return self.softmax(x)
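

# An illustrative check, not required by training: a single forward pass should
# return a length-2 probability vector that sums to 1. check_policy_output is a
# hypothetical helper added for demonstration.
def check_policy_output(policy, state_value=0.0):
    probs = policy(torch.tensor([state_value], dtype=torch.float32))
    assert torch.isclose(probs.sum(), torch.tensor(1.0)), "probabilities must sum to 1"
    return probs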


# Policy gradient algorithm (REINFORCE)
def reinforce(env, policy, optimizer, episodes=1000, gamma=0.99):
    episode_rewards = []
    best_reward = -float('inf')
    best_path = []

    for episode in range(episodes):
        state = env.reset()
        state = torch.tensor([state], dtype=torch.float32)
        done = False
        rewards = []
        log_probs = []
        path = []  # records the path taken in the current episode

        while not done:
            # Sample an action from the current policy
            action_probs = policy(state)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()

            # Execute the action and observe the outcome
            next_state, reward, done = env.step(action.item())
            next_state = torch.tensor([next_state], dtype=torch.float32)

            # Store the reward and the log-probability of the chosen action
            rewards.append(reward)
            log_probs.append(dist.log_prob(action))
            path.append(state.item())  # record the current position

            state = next_state

        path.append(state.item())  # include the final state in the path

        # Compute discounted returns by iterating backwards over the rewards
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)

        # Compute the REINFORCE loss and update the policy
        returns = torch.tensor(returns, dtype=torch.float32)
        log_probs = torch.stack(log_probs)
        loss = -torch.sum(log_probs * returns)
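        # Note: many REINFORCE implementations also subtract a baseline or
        # normalize the returns to reduce gradient variance; that tweak is
        # omitted here to keep the example minimal.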

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_reward = sum(rewards)
        episode_rewards.append(total_reward)

        if total_reward > best_reward:
            best_reward = total_reward
            best_path = path

        if (episode + 1) % 100 == 0:
            print(f"Episode {episode + 1}, Total Reward: {total_reward}, Best Reward: {best_reward}")

    return episode_rewards, best_path
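

# A hedged evaluation sketch; greedy_rollout is a hypothetical helper, not part
# of REINFORCE itself. It rolls out the trained policy deterministically by
# taking the argmax action at each step instead of sampling.
def greedy_rollout(env, policy):
    state = env.reset()
    path = [state]
    done = False
    with torch.no_grad():
        while not done:
            probs = policy(torch.tensor([state], dtype=torch.float32))
            action = torch.argmax(probs).item()
            state, _, done = env.step(action)
            path.append(state)
    return path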


# Initialize the environment and the policy network
env = SimpleMazeEnv()
input_dim = 1  # the state is a single scalar
output_dim = 2  # two actions: move left or move right
policy = PolicyNetwork(input_dim, output_dim)
optimizer = optim.Adam(policy.parameters(), lr=0.001)

# Train the policy
episode_rewards, best_path = reinforce(env, policy, optimizer, episodes=1000)
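
# A quick convergence check (illustrative): if training made progress, the mean
# reward over the last 100 episodes should be well above that of the first 100.
print(f"Mean reward, first 100 episodes: {np.mean(episode_rewards[:100]):.1f}")
print(f"Mean reward, last 100 episodes:  {np.mean(episode_rewards[-100:]):.1f}")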

# Visualize the training results
plt.figure(figsize=(12, 6))

# Plot the reward curve
plt.subplot(1, 2, 1)
plt.plot(episode_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Training Progress')
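
# Optional smoothing overlay (a sketch; the window size of 50 is arbitrary):
# a moving average makes the trend easier to read than raw per-episode rewards.
window = 50
smoothed = np.convolve(episode_rewards, np.ones(window) / window, mode='valid')
plt.plot(range(window - 1, len(episode_rewards)), smoothed, label=f'{window}-episode mean')
plt.legend()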

# Plot the best path found during training
plt.subplot(1, 2, 2)
plt.plot(best_path, marker='o', markersize=5, label="Best Path")
for i, coord in enumerate(best_path):
    plt.text(i, coord, f"({i}, {coord})", fontsize=8)  # annotate each point with (step, state)
plt.xlabel('Steps')
plt.ylabel('State')
plt.title('Best Path Taken')
plt.legend()

plt.tight_layout()
plt.show()
